# -*-Python-*-
# ZMidiAutoencoder

import ddsp
import ddsp.training

include 'eval/midi_ae.gin'
include 'models/midiae/mixins/recon_lossgroup.gin'

# Training
# Binds the `num_steps` argument of `train`.
# NOTE(review): presumably the total number of training steps — confirm
# against the `train` configurable in ddsp.training.
train.num_steps = 500000


# ==============================================================================
# Model
# ==============================================================================
# Supply a ZMidiAutoencoder as the `model` argument of `get_model`. The
# trailing `()` makes this an evaluated Gin reference: the configurable is
# called (using the ZMidiAutoencoder.* bindings in this file) and its result is
# passed, rather than the class itself.
get_model.model = @models.ZMidiAutoencoder()


# Preprocessor
# The autoencoder's `preprocessor` is an F0LoudnessPreprocessor built with
# time_steps = 1000.
# NOTE(review): 1000 is also the value of
# MfccTimeDistributedRnnEncoder.z_time_steps below — presumably the two must
# stay in sync; confirm before changing either one.
ZMidiAutoencoder.preprocessor = @preprocessing.F0LoudnessPreprocessor()
F0LoudnessPreprocessor.time_steps = 1000


# Synth encoder
# `z_synth_encoders` receives a list of two encoders: a OneHotEncoder (its
# parameters are bound in the "OneHotEncoder" section of this file) and a
# MfccTimeDistributedRnnEncoder, whose parameters are bound right here.
ZMidiAutoencoder.z_synth_encoders = [
    @encoders.OneHotEncoder(),
    @encoders.MfccTimeDistributedRnnEncoder(),
]
# GRU with 512 channels; z_dims / z_time_steps are presumably the size and the
# number of time frames of the produced latent — TODO confirm in ddsp encoders.
MfccTimeDistributedRnnEncoder.rnn_channels = 512
MfccTimeDistributedRnnEncoder.rnn_type = 'gru'
MfccTimeDistributedRnnEncoder.z_dims = 128
MfccTimeDistributedRnnEncoder.z_time_steps = 1000


# OneHotEncoder
# These bindings are unscoped, so they apply to every
# `@encoders.OneHotEncoder()` reference in this file, i.e. to the entry in
# `z_synth_encoders` and to the entry in `z_global_encoders` alike.
# The encoder is keyed on 'instrument_id' with a vocabulary of 13 entries.
# NOTE(review): n_dims is presumably the width of the resulting embedding, and
# skip_expand presumably controls expansion along time — confirm in ddsp.
OneHotEncoder.one_hot_key = 'instrument_id'
OneHotEncoder.vocab_size = 13  # num instruments
OneHotEncoder.n_dims = 128
OneHotEncoder.skip_expand = False


# Synthcoder
# The `synthcoder` is a DilatedConvDecoder: 2 stacks of 9 layers with 128
# channels and 'layer' normalization. It reads 'ld_scaled' and 'f0_scaled' as
# inputs, is conditioned on 'z', and uses no precondition stack.
ZMidiAutoencoder.synthcoder = @decoders.DilatedConvDecoder()
DilatedConvDecoder.ch = 128
DilatedConvDecoder.layers_per_stack = 9
DilatedConvDecoder.norm_type = 'layer'
DilatedConvDecoder.input_keys = ('ld_scaled', 'f0_scaled')
DilatedConvDecoder.stacks = 2
DilatedConvDecoder.conditioning_keys = ('z',)
DilatedConvDecoder.precondition_stack = None
# Presumably (output name, width) pairs used to split the decoder output; the
# widths listed here add up to 1 + 60 + 65 = 126 — TODO confirm semantics.
DilatedConvDecoder.output_splits = (('amplitudes', 1),
                                    ('harmonic_distribution', 60),
                                    ('magnitudes', 65))


# Stop Gradient
# NOTE(review): presumably blocks gradients from flowing out of the MIDI
# autoencoder part back into the synthesis part — confirm in ZMidiAutoencoder.
ZMidiAutoencoder.sg_before_midiae = True


# MIDI Encoder
# No MIDI encoder is configured for this model; the argument is explicitly
# bound to None.
ZMidiAutoencoder.midi_encoder = None


# MIDI Global Latents
# `z_global_encoders` receives a list of two encoders: an unscoped
# OneHotEncoder (same bindings as the one in `z_synth_encoders`) and an
# ExpressionEncoder evaluated in the `ee` scope, so that it picks up the
# `ee/ExpressionEncoder.*` bindings rather than the `zn/` ones.
ZMidiAutoencoder.z_global_encoders = [
    @encoders.OneHotEncoder(),
    @ee/encoders.ExpressionEncoder(),
]


# Global Expression Encoder
# Bindings for the ExpressionEncoder in the `ee` scope (the one listed in
# `z_global_encoders`). It has the same input keys and z_dims as the `zn`
# scoped note encoder; the visible difference is pool_time = True here versus
# False there.
# NOTE(review): pool_time presumably pools over the time axis to yield a single
# global latent — confirm in ddsp encoders.
ee/ExpressionEncoder.input_keys = ('f0_scaled',
                                   'amps_scaled',
                                   'hd_scaled',
                                   'noise_scaled',)
ee/ExpressionEncoder.z_dims = 128
ee/ExpressionEncoder.pool_time = True
# The network is a DilatedConvStack evaluated in the `z` scope, i.e. built from
# the `z/DilatedConvStack.*` bindings.
ee/ExpressionEncoder.net = @z/nn.DilatedConvStack()


# Z Note Encoder
# The `z_note_encoder` is an ExpressionEncoder evaluated in the `zn` scope. It
# has the same input keys and z_dims as the `ee` scoped global encoder; the
# visible difference is pool_time = False here versus True there.
ZMidiAutoencoder.z_note_encoder = @zn/encoders.ExpressionEncoder()
zn/ExpressionEncoder.input_keys = ('f0_scaled',
                                   'amps_scaled',
                                   'hd_scaled',
                                   'noise_scaled',)
zn/ExpressionEncoder.z_dims = 128
zn/ExpressionEncoder.pool_time = False
# The network is a DilatedConvStack evaluated in the `z` scope, i.e. built from
# the `z/DilatedConvStack.*` bindings.
zn/ExpressionEncoder.net = @z/nn.DilatedConvStack()


# Shared DilatedConvStack settings (scope `z`) for the Note encoder and the
# Global Expression Encoder: both bind their `net` to
# `@z/nn.DilatedConvStack()`.
# NOTE(review): each evaluated reference presumably constructs its own stack,
# so the two encoders share these hyperparameters, not weights — confirm.
z/DilatedConvStack.ch = 128
z/DilatedConvStack.layers_per_stack = 5
z/DilatedConvStack.stacks = 4
z/DilatedConvStack.norm_type = 'layer'
z/DilatedConvStack.conditional = False


# Z Preconditioning Stack
# The `z_preconditioning_stack` is an FcStackOut with 5 layers, ch = 512 and
# n_out = 256. Unlike most bindings in this file, these use the
# module-qualified name `nn.FcStackOut`.
# NOTE(review): ch is presumably the hidden width and n_out the output width —
# confirm in ddsp.training.nn.
ZMidiAutoencoder.z_preconditioning_stack = @nn.FcStackOut()
nn.FcStackOut.ch = 512
nn.FcStackOut.n_out = 256
nn.FcStackOut.layers = 5


# MIDI Decoder
# The `midi_decoder` is a MidiToHarmonicDecoder with f0_residual and norm both
# enabled.
# NOTE(review): f0_residual presumably makes the decoder predict f0 as an
# offset from the MIDI pitch — confirm in ddsp decoders.
ZMidiAutoencoder.midi_decoder = @decoders.MidiToHarmonicDecoder()
MidiToHarmonicDecoder.f0_residual = True
MidiToHarmonicDecoder.norm = True
# Same three entries as DilatedConvDecoder.output_splits, preceded by an extra
# ('f0_midi', 1) entry.
MidiToHarmonicDecoder.output_splits = (('f0_midi', 1),
                                       ('amplitudes', 1),
                                       ('harmonic_distribution', 60),
                                       ('magnitudes', 65))
# The decoder network is a DilatedConvStack in its own `dec` scope. Its values
# match the `z` scope (ch, layers_per_stack, stacks, norm_type) except that
# conditional is True here and False there.
MidiToHarmonicDecoder.net = @dec/nn.DilatedConvStack()
dec/DilatedConvStack.ch = 128
dec/DilatedConvStack.layers_per_stack = 5
dec/DilatedConvStack.stacks = 4
dec/DilatedConvStack.norm_type = 'layer'
dec/DilatedConvStack.conditional = True


# ==============================================================================
# Losses
# ==============================================================================
# Reconstruction Loss
# `reconstruction_losses` is a LossGroup evaluated in the `recon_lossgroup`
# scope. No `recon_lossgroup/...` bindings appear in this file.
# NOTE(review): they presumably come from the included
# 'models/midiae/mixins/recon_lossgroup.gin' — confirm.
ZMidiAutoencoder.reconstruction_losses = @recon_lossgroup/losses.LossGroup()


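# Optional Gaussian prior losses on the global and note latents. Both are
# currently disabled; uncomment the corresponding three lines to enable one.
# ZMidiAutoencoder.z_global_prior = @zg/losses.GaussianPrior()
# zg/GaussianPrior.weight = 1.0
# zg/GaussianPrior.name = 'global_kl'


# ZMidiAutoencoder.z_note_prior = @zn/losses.GaussianPrior()
# zn/GaussianPrior.weight = 1.0
# zn/GaussianPrior.name = 'note_kl'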
